{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "import torch.optim as optim\n",
    "from efficientnet_pytorch import EfficientNet\n",
    "from pathlib import Path\n",
    "import numpy as np\n",
    "import multiprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# EfficientNet variant to load (the N in efficientnet-bN)\n",
    "v = 0\n",
    "# channel count of the input tensor\n",
    "in_c = 2\n",
    "# number of classes the model predicts\n",
    "num_c = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random stand-in with the shape of one optical flow input\n",
    "of_shape = (1, 2, 640, 480)\n",
    "of = torch.randn(of_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained weights for efficientnet-b0\n"
     ]
    }
   ],
   "source": [
    "# Load the pretrained EfficientNet variant `v`, configured for `in_c` input\n",
    "# channels and `num_c` outputs, then place it on the selected device.\n",
    "model = EfficientNet.from_pretrained(\n",
    "    f'efficientnet-b{v}',\n",
    "    in_channels=in_c,\n",
    "    num_classes=num_c,\n",
    ")\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### The output of the model will look like this"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.08474025130271912"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "of = of.to(device)\n",
    "# Inference-only probe: eval mode keeps this random input from updating the\n",
    "# pretrained BatchNorm running statistics (and disables stochastic layers),\n",
    "# while no_grad skips building an autograd graph.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    probe_out = model(of)\n",
    "# single-element output -> plain Python float, shown as the cell result\n",
    "probe_out.item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# text file holding the labels\n",
    "labels_f = '<YOUR FILE>'\n",
    "# folder containing the optical flow arrays (*.npy)\n",
    "of_dir = '<YOUR DIRECTORY>'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class OFDataset(Dataset):\n",
    "    \"\"\"Optical flow arrays stored as ``<idx>.npy`` files, paired with labels from a text file.\n",
    "\n",
    "    Args:\n",
    "        of_dir: directory holding the ``*.npy`` files; sample ``idx`` is read from ``<idx>.npy``.\n",
    "        label_f: text file whose line ``idx`` starts with the label of sample ``idx``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, of_dir, label_f):\n",
    "        self.of_dir = of_dir\n",
    "        # dataset size is the number of .npy files found in the directory\n",
    "        self.len = len(list(Path(of_dir).glob('*.npy')))\n",
    "        # read every label line up front and close the file immediately\n",
    "        with open(label_f) as labels:\n",
    "            self.label_file = labels.readlines()\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.len\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return ``[of_tensor, label]`` for sample ``idx``.\n",
    "\n",
    "        ``of_tensor`` is a float32 tensor with all size-1 dimensions removed;\n",
    "        ``label`` is the first whitespace-separated token of line ``idx`` as a float.\n",
    "        \"\"\"\n",
    "        of_array = np.load(Path(self.of_dir)/f'{idx}.npy')\n",
    "        # explicit dtype instead of the legacy torch.Tensor(ndarray) constructor\n",
    "        of_tensor = torch.as_tensor(of_array, dtype=torch.float32).squeeze()\n",
    "        label = float(self.label_file[idx].split()[0])\n",
    "        return [of_tensor, label]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset over the configured optical flow directory and label file\n",
    "ds = OFDataset(of_dir=of_dir, label_f=labels_f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of the samples used for training (80%);\n",
    "# the remaining 20% is held out for validation.\n",
    "train_split = 0.8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Contiguous split: the first `split` indices train, the rest validate.\n",
    "ds_size = len(ds)\n",
    "split = int(np.floor(train_split * ds_size))\n",
    "indices = list(range(ds_size))\n",
    "train_idx = indices[:split]\n",
    "val_idx = indices[split:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check on one item: the dataset should yield a tensor and a float label.\n",
    "sample = ds[3]\n",
    "sample_of, sample_label = sample\n",
    "# isinstance is the idiomatic type check (exact `type(x) == T` comparison rejects subclasses)\n",
    "assert isinstance(sample_of, torch.Tensor)\n",
    "assert isinstance(sample_label, float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One sampler per split, each restricted to that split's index list\n",
    "train_sampler, val_sampler = (\n",
    "    SubsetRandomSampler(idx_list) for idx_list in (train_idx, val_idx)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# CPU count reported by multiprocessing; echoed as the cell result\n",
    "cpu_cores = multiprocessing.cpu_count()\n",
    "cpu_cores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both loaders read from the same dataset and differ only in their sampler.\n",
    "loader_kwargs = dict(batch_size=8, num_workers=cpu_cores)\n",
    "train_dl = DataLoader(ds, sampler=train_sampler, **loader_kwargs)\n",
    "val_dl = DataLoader(ds, sampler=val_sampler, **loader_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of passes over the training data\n",
    "epochs = 1000\n",
    "# presumably the number of training steps between log messages\n",
    "log_train_steps = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean-squared-error objective, optimised by Adam with its default settings\n",
    "criterion = nn.MSELoss()\n",
    "opt = optim.Adam(params=model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(epochs):\n",
    "    # ---- training ----\n",
    "    model.train()\n",
    "    running_loss = 0.0\n",
    "    for i, sample in enumerate(train_dl):\n",
    "        # use the device selected earlier rather than .cuda(), which fails without a GPU\n",
    "        of_tensor = sample[0].to(device)\n",
    "        label = sample[1].float().to(device)\n",
    "        opt.zero_grad()\n",
    "        # squeeze only dim 1 (the single-output dim when num_c == 1) so a final\n",
    "        # batch of size 1 keeps shape (1,) and still matches `label`\n",
    "        pred = model(of_tensor).squeeze(1)\n",
    "        loss = criterion(pred, label)\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        # report the mean training loss every `log_train_steps` batches\n",
    "        running_loss += loss.item()\n",
    "        if (i + 1) % log_train_steps == 0:\n",
    "            print(f'{epoch} [{i + 1}]: train loss {running_loss / log_train_steps}')\n",
    "            running_loss = 0.0\n",
    "    # ---- validation ----\n",
    "    model.eval()\n",
    "    val_losses = []\n",
    "    with torch.no_grad():\n",
    "        for val_sample in val_dl:\n",
    "            of_tensor = val_sample[0].to(device)\n",
    "            label = val_sample[1].float().to(device)\n",
    "            pred = model(of_tensor).squeeze(1)\n",
    "            # keep plain Python floats so the average prints as a number, not a tensor repr\n",
    "            val_losses.append(criterion(pred, label).item())\n",
    "    # skip the report when the validation loader yielded no batches (avoids dividing by zero)\n",
    "    if val_losses:\n",
    "        print(f'{epoch}: {sum(val_losses)/len(val_losses)}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
